
using LegibleLambdas, AbstractTrees, PlutoUI, HypertextLiteral, PlutoTest

md"## Let's write our own reverse-mode AD!"

md"""
For simplicity, we rely on Julia's multiple dispatch. Concretely, we define a type `Tracked` that keeps track of our input variables and of everything we will need to compute the gradient later.
"""
begin
    struct Tracked{T} <: Number
        # The numerical result when doing the forward pass
        val::T
        # A label for printing: the variable name, the operation (e.g. :+) for computed values, or a gensym for constants
        name::Symbol
        # The pullback map for the reverse pass
        df
        # All the other variables this variable directly depends on
        deps::Vector{Tracked}
    end
    Tracked{T}(x, name=gensym()) where {T} = Tracked{T}(x, name, nothing, Tracked[])
    Base.convert(T::Type{Tracked{S}}, x::Tracked) where {S} = T(convert(S, x.val), x.name, x.df, x.deps)
    # This tells Julia to convert any number combined with a `Tracked` (e.g. in `2x` or `x - 1`) to a `Tracked` first
    Base.promote_rule(::Type{Tracked{S}}, ::Type{T}) where {S<:Number, T<:Number} = Tracked{promote_type(S, T)}
end
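The `promote_rule` and `convert` methods above let mixed expressions such as `2x` or `x - 1` reach the methods defined below for two `Tracked` arguments. Base's generic fallback `+(::Number, ::Number)` promotes both arguments to a common type and then calls `+` again. Below is a minimal sketch of the same mechanism. It uses a hypothetical wrapper type `Wrapped` that is not part of the notebook and that defines `+` only for two `Wrapped` arguments:

```julia
# Hypothetical wrapper type, used only to illustrate Julia's promotion machinery.
struct Wrapped{T} <: Number
    val::T
end

# Re-wrap with a different element type by converting the wrapped value.
Base.convert(::Type{Wrapped{S}}, x::Wrapped) where {S} = Wrapped{S}(convert(S, x.val))
# Wrap a plain number.
Base.convert(::Type{Wrapped{S}}, x::Number) where {S} = Wrapped{S}(convert(S, x))
# Combining a Wrapped{S} with any other number yields a Wrapped of the promoted element type.
Base.promote_rule(::Type{Wrapped{S}}, ::Type{T}) where {S<:Number, T<:Number} =
    Wrapped{promote_type(S, T)}

# `+` is defined only for two Wrapped arguments ...
Base.:+(a::Wrapped, b::Wrapped) = Wrapped(a.val + b.val)

# ... yet a mixed call works: Base's `+(::Number, ::Number)` promotes
# 2.5 and Wrapped(1) to Wrapped{Float64} and then calls the method above.
r = 2.5 + Wrapped(1)
r.val
```

The `Tracked` definitions follow the same pattern. The only difference is that their `convert` also carries over the name, pullback, and dependencies.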

md"""
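Each overload performs the operation (e.g. sums `x` and `y`) and also records the pullback map and the input variables for the reverse pass. The pullback takes the gradient `Δ` with respect to the operation's output and returns one gradient per input. For `x * y`, for example, it returns `(Δ * y.val, Δ * x.val)`.

`@λ` (from the *LegibleLambdas.jl* package) only changes how a pullback is printed. Writing `@λ(Δ -> (Δ, Δ))` instead of the plain `Δ -> (Δ, Δ)` gives nicer output but computes the same thing.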
"""
function Base.:+(x::Tracked, y::Tracked)
    Tracked(x.val + y.val, :+, (Δ -> (Δ, Δ)), Tracked[x, y])
end

function Base.:-(x::Tracked, y::Tracked)
    Tracked(x.val - y.val, :-, (Δ -> (Δ, -Δ)), Tracked[x, y])
end

function Base.:*(x::Tracked, y::Tracked)
    Tracked(x.val * y.val, :*, (Δ -> (Δ * y.val, Δ * x.val)), Tracked[x, y])
end

function Base.:/(x::Tracked, y::Tracked)
    Tracked(x.val / y.val, :/, (Δ -> (Δ / y.val, -Δ * x.val / y.val^2)), Tracked[x, y])
end

function Base.:^(x::Tracked, n::Int)
    Tracked(x.val^n, Symbol("^$n"), (Δ -> (Δ * n * x.val^(n-1),)), Tracked[x,])
end

function Base.sin(x::Tracked)
    Tracked(sin(x.val), :sin, (Δ -> (Δ * cos(x.val),)), Tracked[x,])
end

function Base.exp(x::Tracked)
    Tracked(exp(x.val), :exp, (Δ -> (Δ * exp(x.val),)), Tracked[x,])
end
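With these rules, a composite computation is differentiated by running the recorded pullbacks in reverse order. Each pullback feeds its result into the pullback of the step before it, which is the chain rule. The sketch below does this on plain numbers for `s = sin(x * y)`. `sin_of_product` is a hypothetical helper, and its two closures copy the rules of the `*` and `sin` overloads above:

```julia
# Forward pass for s = sin(x * y) that also returns a pullback for the whole computation.
function sin_of_product(x, y)
    p = x * y
    mul_pullback = Δ -> (Δ * y, Δ * x)   # same rule as the `*` overload
    s = sin(p)
    sin_pullback = Δ -> (Δ * cos(p),)    # same rule as the `sin` overload
    # Reverse pass: apply the pullbacks in the opposite order of the forward pass.
    function pullback(Δs)
        (Δp,) = sin_pullback(Δs)
        return mul_pullback(Δp)
    end
    return s, pullback
end

s, back = sin_of_product(2.0, 7.0)
dx, dy = back(1.0)    # seed the output gradient with 1
# By the chain rule: ∂s/∂x = cos(x*y) * y and ∂s/∂y = cos(x*y) * x
```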

md"""
`Tracked` is a tree: we only need to tell *AbstractTrees.jl* how to get the children of each node, and we get tree printing and iteration over all nodes for free.
"""
AbstractTrees.children(x::Tracked) = x.deps
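Implementing only `children` is enough because generic printing and traversal code can be written once against that single function. The sketch below illustrates the idea with a hand-written recursive printer. `Node` and `print_tree_text` are hypothetical names. The output format only imitates `print_tree` and does not reproduce AbstractTrees.jl's implementation:

```julia
# Hypothetical node type, standing in for `Tracked`: a label plus child nodes.
struct Node
    label::String
    kids::Vector{Node}
end
Node(label::String) = Node(label, Node[])

# Generic printer: it only needs a `children` function and a `label` function,
# so it works for any tree type that can provide those two things.
function print_tree_text(io::IO, node, children, label; prefix="")
    println(io, label(node))
    cs = children(node)
    for (i, c) in enumerate(cs)
        islast = i == length(cs)
        print(io, prefix, islast ? "└─ " : "├─ ")
        print_tree_text(io, c, children, label; prefix = prefix * (islast ? "   " : "│  "))
    end
end

# The same shape as the tree printed for `w = x*y + x` further down.
t = Node("(16, +)", [Node("(14, *)", [Node("x=2"), Node("y=7")]), Node("x=2")])
print_tree_text(stdout, t, n -> n.kids, n -> n.label)
```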

md"""
Let's also overload `show` for nicer output:
"""
begin
    # All this is just for nicer printing
    function Base.show(io::IO, x::Tracked)
        if x.df === nothing
            print(io, Base.isgensym(x.name) ? x.val : "$(x.name)=$(x.val)")
        else
            print(io, "(")
            show(io, x.val)
            print(io, ", ")
            print(io, x.name)
            print(io, ")")
        end
    end
    Base.show(io::IO, ::MIME"text/plain", x::Tracked) = print_tree(io, x)
end

md"""
Create some variables we want to eventually differentiate with respect to.
"""
begin
    x = Tracked{Int}(2, :x)
    y = Tracked{Int}(7, :y)
end

md"""
Straight away we get the primal result of our calculation:
"""
(2x*y + (x-1)^2).val # The result of `2x*y + (x-1)^2`, which is 29

md"""
To also get the gradient, we'll use `PreOrderDFS` to traverse the tree we just created from the top down.
"""
(29, +)
├─ (28, *)
│  ├─ (4, *)
│  │  ├─ 2
│  │  └─ x=2
│  └─ y=7
└─ (1, ^2)
   └─ (1, -)
      ├─ x=2
      └─ 1
z = (2x*y + (x-1)^2)
# `PreOrderDFS` traverses this tree from the top down
Text.(collect(PreOrderDFS(z)))
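`PreOrderDFS` yields every node before its children. The output `z` therefore comes first, and each branch is followed down to its leaves before the next branch starts. Below is a minimal sketch of that ordering using an explicit stack. It works on a hypothetical `(label, children)` tuple encoding of the tree printed for `z`, not on the notebook's `Tracked` type, and it is not AbstractTrees.jl's implementation:

```julia
# Each node is a (label, children) tuple; this mirrors the tree printed for `z`.
leaf(label) = (label, Any[])
ztree = ("(29, +)", Any[
    ("(28, *)", Any[("(4, *)", Any[leaf("2"), leaf("x=2")]), leaf("y=7")]),
    ("(1, ^2)", Any[("(1, -)", Any[leaf("x=2"), leaf("1")])]),
])

# Pre-order depth-first traversal: visit a node, then its children left to right.
function preorder(root)
    order = String[]
    stack = Any[root]
    while !isempty(stack)
        label, kids = pop!(stack)
        push!(order, label)
        # Push children in reverse so the leftmost child is popped first.
        append!(stack, reverse(kids))
    end
    return order
end

preorder(ztree)
```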

md"""
Now let's write our function `grad`, which accumulates all intermediate gradients in a dictionary:
"""
function grad(f::Tracked)
    d = Dict{Any, Any}(f => 1)
    for x in PreOrderDFS(f) # traverse all nodes, from the output down to the inputs
        x.df === nothing && continue # skip leaves (inputs and constants), which have no pullback
        dy = x.df(d[x]) # evaluate the pullback
        for (yᵢ, dyᵢ) in zip(x.deps, dy)
            # store the gradient in d;
            # if we have already stored a gradient for this variable, add the new contribution to it
            d[yᵢ] = get(d, yᵢ, 0) + dyᵢ
        end
    end
    return d
end
grad(f::Tracked, x::Tracked) = grad(f)[x]
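A caveat about `grad`: `PreOrderDFS` treats the graph as a tree and visits a node once for every path that reaches it. That is harmless for shared leaves such as `x` in `w = x*y + x` further below, because leaves have no pullback. An intermediate value used twice, e.g. `a = x*y; a + a`, is different: its pullback is applied twice, each time with its full accumulated gradient, so the gradients of its inputs are over-counted. A common remedy is to run the pullbacks in topological order, so that each one runs exactly once and only after all uses of that node have been accumulated. The sketch below is self-contained, reproduces both traversal orders, and uses a hypothetical minimal `Var` type rather than `Tracked`:

```julia
# Hypothetical minimal stand-in for `Tracked`: value, pullback (or nothing), inputs.
struct Var
    val::Float64
    df::Union{Nothing, Function}
    deps::Vector{Var}
end
Var(v::Real) = Var(Float64(v), nothing, Var[])

Base.:+(a::Var, b::Var) = Var(a.val + b.val, Δ -> (Δ, Δ), Var[a, b])
Base.:*(a::Var, b::Var) = Var(a.val * b.val, Δ -> (Δ * b.val, Δ * a.val), Var[a, b])

# Shared reverse-pass step: apply n's pullback and accumulate into its inputs.
function pull!(d, n::Var)
    n.df === nothing && return
    for (dep, g) in zip(n.deps, n.df(d[n]))
        d[dep] = get(d, dep, 0.0) + g
    end
end

# Pre-order, as in `grad`: a node is visited once per path that reaches it.
function grad_preorder(f::Var)
    d = Dict{Var, Float64}(f => 1.0)
    stack = Var[f]
    while !isempty(stack)
        n = pop!(stack)
        pull!(d, n)
        append!(stack, reverse(n.deps))
    end
    return d
end

# Topological order (reverse DFS post-order): each pullback runs exactly once,
# after every use of the node has contributed to its gradient.
function grad_topo(f::Var)
    seen, post = Set{Var}(), Var[]
    function visit(n)
        n in seen && return
        push!(seen, n)
        foreach(visit, n.deps)
        push!(post, n)
    end
    visit(f)
    d = Dict{Var, Float64}(f => 1.0)
    foreach(n -> pull!(d, n), reverse(post))
    return d
end

xv, yv = Var(2), Var(7)
av = xv * yv     # intermediate result, used twice below
zv = av + av     # zv = 2xy, so ∂zv/∂x = 2y = 14 and ∂zv/∂y = 2x = 4
gp, gt = grad_preorder(zv), grad_topo(zv)
(gp[xv], gp[yv])   # (28.0, 8.0): the pullback of `av` ran twice
(gt[xv], gt[yv])   # (14.0, 4.0)
```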

md"""
We can verify that it does the right thing:
"""
(16, +)
├─ (14, *)
│  ├─ x=2
│  └─ y=7
└─ x=2
w = x*y + x
grad(w, x), grad(w, y)
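For `w = x*y + x`, the exact partials at x = 2, y = 7 are ∂w/∂x = y + 1 = 8 and ∂w/∂y = x = 2, which is what `grad` should return. An independent check that needs no AD machinery is the central finite difference (f(a + h) - f(a - h)) / 2h. `central_diff` and `wfun` below are hypothetical helpers on plain floats:

```julia
# Central finite-difference approximation of the partial derivative of f
# with respect to its i-th argument, evaluated at the point `args`.
function central_diff(f, args::Vector{Float64}, i::Int; h=1e-6)
    up, down = copy(args), copy(args)
    up[i] += h
    down[i] -= h
    return (f(up...) - f(down...)) / (2h)
end

wfun(a, b) = a * b + a    # same formula as `w = x*y + x`
pt = [2.0, 7.0]
(central_diff(wfun, pt, 1), central_diff(wfun, pt, 2))    # ≈ (8.0, 2.0)
```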

md"""
## How can we visualize both the forward and the reverse pass?

We can also visualize each step we just took. First we do the forward calculation, which also builds up our tree; then we walk down the tree in the opposite direction to accumulate the gradient.
"""
👽 = Tracked{Int}(12, :👽)
#ex = :(3y*x + 2(x-1)*x)
# ex = :(x^3 + y + sin(👽))
ex = :(x*exp( (-.5)*(x^2+y^2)))
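To make both passes concrete for `ex = x*exp(-0.5*(x^2 + y^2))`, the sketch below unrolls the same computation by hand. The forward pass computes intermediate values `v1`…`v6`. The reverse pass then fills in their adjoints `dv6`…`dv1` (the derivatives of the output with respect to each intermediate) in reverse order. It uses plain `Float64`s at the notebook's point x = 2, y = 7, and the variable names are illustrative. The result should match the analytic gradient ∂ex/∂x = (1 - x²)·exp(-(x² + y²)/2) and ∂ex/∂y = -xy·exp(-(x² + y²)/2):

```julia
# Plain Float64 inputs at the point used throughout the notebook.
x0, y0 = 2.0, 7.0

# Forward pass: compute the intermediates from the leaves up to the output.
v1 = x0^2
v2 = y0^2
v3 = v1 + v2
v4 = -0.5 * v3
v5 = exp(v4)
v6 = x0 * v5                # value of `ex`

# Reverse pass: dvk holds ∂v6/∂vk, filled in from the output back to the inputs.
dv6 = 1.0
dv5 = dv6 * x0              # v6 = x0 * v5
dx0 = dv6 * v5              # x0 also enters v6 directly
dv4 = dv5 * exp(v4)         # v5 = exp(v4)
dv3 = -0.5 * dv4            # v4 = -0.5 * v3
dv1, dv2 = dv3, dv3         # v3 = v1 + v2
dx0 += dv1 * 2 * x0         # v1 = x0^2: second contribution to ∂v6/∂x0
dy0 = dv2 * 2 * y0          # v2 = y0^2
(dx0, dy0)
```

`x0` is used in two places (`v1` and `v6`), so its adjoint receives two contributions. `grad` handles this with `get(d, yᵢ, 0) + dyᵢ`.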
steps = ad_steps(ex);
html"<span style='color: red; font-size: 1.5em'>Move me!</span>"
@bind i Slider(1:length(steps))

md"""
We can also visualize what Julia does in the forward pass on the code itself:
"""
[Output: an animation stepping through the evaluation of `3y + 2(x - 1)`. First `3y`, then `x - 1`, then `2 * (x - 1)`, and finally the sum are each replaced by their `Tracked` trees. Final frame:]
(23, +)
├─ (21, *)
│  ├─ 3
│  └─ y=7
└─ (2, *)
   ├─ 2
   └─ (1, -)
      ├─ x=2
      └─ 1
@visual_debug 3y + 2(x-1)
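`@visual_debug`, defined at the end of the notebook on top of PlutoTest's `@eval_step_by_step`, replaces each sub-expression by its value, innermost first. The underlying idea can be sketched as a post-order walk over the `Expr`: evaluate the children, then the call, and record each intermediate result. `eval_steps` is a hypothetical helper, not PlutoTest's implementation. For simplicity it takes variable values from a `Dict` and resolves function names only in `Base`:

```julia
# Evaluate `ex` bottom-up, recording (sub-expression => value) for every call.
# Variables are looked up in `env`; function names are resolved in `Base`.
function eval_steps(ex, env::Dict{Symbol, Any}, record=Pair{Any, Any}[])
    if ex isa Symbol
        return env[ex], record                      # variable lookup
    elseif Meta.isexpr(ex, :call)
        f = getfield(Base, ex.args[1])              # e.g. :* -> Base.:*
        args = [first(eval_steps(a, env, record)) for a in ex.args[2:end]]
        val = f(args...)
        push!(record, ex => val)                    # recorded after all its children
        return val, record
    else
        return ex, record                           # literal such as 3 or 1
    end
end

result, trace = eval_steps(:(3y + 2(x - 1)), Dict{Symbol, Any}(:x => 2, :y => 7))
foreach(println, trace)
```

The recorded values follow the same order as the frames above: `3y`, then `x - 1`, then `2(x - 1)`, then the whole sum.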
function ad_steps(x::Expr; color_fwd="red", color_bwd="green", font_size=".8em")
    x = EX(x)
    repr(x) = sprint(show, x; context=:compact=>true)
    span_fwd = @htl "<span style='color: $color_fwd; font-size: $font_size'>"
    span_bwd = @htl "<span style='color: $color_bwd; font-size: $font_size'>"

    d1 = Dict(
        let e = eval(i.x)
            i => @htl "&ensp;$(span_fwd)$(repr(e isa Tracked ? e.val : e))</span>"
        end
        for i in PostOrderDFS(x) if isempty(children(i))
    )
    res = accumulate(Iterators.filter(x -> !isempty(children(x)), PostOrderDFS(x)); init=d1) do d, i
        d = copy(d)
        e = eval(i.x)
        d[i] = @htl "&ensp;$(span_fwd)$(repr(e isa Tracked ? e.val : e))</span>"
        d
    end
    pushfirst!(res, d1)
    f = eval(x.x)
    d = Dict{Any, Any}(f => 1)
    let d1 = copy(res[end])
begin
    struct EX
        x::Any
        function EX(ex)
            # Rewrite an n-ary sum `a + b + c` as `(a + b) + c` (the inner sum is rewritten again
            # when it is wrapped), matching the left-nested tree that `Tracked` builds
            if Meta.isexpr(ex, :call) && ex.args[1] === :+ && length(ex.args) > 3
                new(Expr(:call, :+, Expr(:call, :+, ex.args[2:end-1]...), ex.args[end]))
            else
                new(ex)
            end
        end
    end
    show_tree(ex::Expr) = show_tree(EX(ex))
    function Base.show(io::IO, ex::EX)
        Base.show_unquoted(io, Meta.isexpr(ex.x, :call) ? ex.x.args[1] : ex.x)
        if Meta.isexpr(ex.x, :call) && ex.x.args[1] === :^
            print(io, ex.x.args[3])
        end
    end
    function AbstractTrees.children(ex::EX)
        if Meta.isexpr(ex.x, :call)
            ex.x.args[1] === :^ ? [EX(ex.x.args[2])] : EX.(ex.x.args[2:end])
        else
            EX[]
        end
    end
    Base.:(==)(ex1::EX, ex2::EX) = ex1.x == ex2.x
    # `hash` must agree with `==` so that EX works as a Dict key (as in `ad_steps` and `TTREE`)
    Base.hash(ex::EX, h::UInt) = hash(ex.x, h)
end
begin
    struct TTREE
        x
        d::Dict{Any, Any}
    end
    function Base.show(io::IO, x::TTREE)
        show(io, x.x)
        print(io, " ")
        show(io, MIME("text/html"), get(x.d, x.x, @htl("")))
    end
    AbstractTrees.children(x::TTREE) = (TTREE(i, x.d) for i in children(x.x))
end
s2 = @htl """
<link rel="stylesheet" href="https://fperucic.github.io/treant-js/Treant.css"/>
<style>
.Treant > .node {
padding: 5px; border: 2px solid #484848; border-radius: 8px;
box-sizing: unset;
min-width: fit-content;
font-size: 1.6em;
}
.Treant > .node > span {
vertical-align: middle;
}

.Treant .collapse-switch { width: 100%; height: 100%; border: none; }
.Treant .node.collapsed { background-color: var(--main-bg-color); }
.Treant .node.collapsed .collapse-switch { background: none;}
</style>

<script src="https://fperucic.github.io/treant-js/vendor/jquery.min.js"></script>
<script src="https://fperucic.github.io/treant-js/vendor/jquery.easing.js"></script>
<script src="https://fperucic.github.io/treant-js/vendor/raphael.js"></script>
function to_json(x)
    # Convert a tree into nested dictionaries (`innerHTML`, `children`) for the Treant.js chart
    d = Dict{Symbol, Any}(
        :innerHTML => sprint(AbstractTrees.printnode, x),
        :children => Any[to_json(c) for c in children(x)],
        #:collapsed => !isempty(children(x)),
    )
end
function show_tree(x; height=400)
    id = gensym()
    @htl """
    <div id="$id" style="width:100%; height: $(height)px"> </div>
    <script>
    var simple_chart_config = {
        chart: {
            container: "#$id",

            //animateOnInit: true,

            node: {
                collapsable: true,
            },

            nodeAlign: "BOTTOM",

            connectors: {
                type: "straight",
                style: {
                    stroke: getComputedStyle(document.documentElement).getPropertyValue('--cm-editor-text-color')
                }
            },
            animation: {
                nodeAnimation: "easeOutBounce",
                nodeSpeed: 500,
                connectorsAnimation: "bounce",
                connectorsSpeed: 500
            },
        },

s1 = html"""
<style>
p-frame-viewer {
display: inline-flex;
flex-direction: column;
}
p-frames,
p-frame-controls {
display: inline-flex;
}
p-frame-controls {
margin-top: 20px;
}
line-like {
font-size: 30px;
}
"""
macro visual_debug(expr)
    quote
        $(esc(:(PlutoTest.@eval_step_by_step($expr)))) .|> PlutoTest.SlottedDisplay |> PlutoTest.frames |> PlutoTest.with_slotted_css
    end
end